@codeflash-ai codeflash-ai bot commented Nov 5, 2025

⚡️ This pull request contains optimizations for PR #867

If you approve this dependent PR, these changes will be merged into the original PR branch inspect-signature-issue.

This PR will be automatically closed if the original PR is merged.


📄 31% (0.31x) speedup for ImportAnalyzer.visit_Call in codeflash/discovery/discover_unit_tests.py

⏱️ Runtime : 13.0 milliseconds → 9.94 milliseconds (best of 41 runs)

📝 Explanation and details

The optimization replaces the standard ast.NodeVisitor.generic_visit call with a custom _fast_generic_visit method that inlines the AST traversal logic, eliminating method resolution overhead and adding more aggressive early-exit checks.

Key Performance Improvements:

  1. Eliminated Method Resolution Overhead: The original code called ast.NodeVisitor.generic_visit(self, node) which incurs method lookup and dispatch costs. The optimized version inlines this logic directly, avoiding the base class method call entirely.

  2. More Frequent Early Exit Checks: The new _fast_generic_visit checks self.found_any_target_function at multiple points during traversal (before processing lists and individual AST nodes), allowing faster short-circuiting when a target function is found.

  3. Optimized Attribute Access: The optimization uses direct getattr calls and caches method lookups (getattr(self, 'visit_' + item.__class__.__name__, None)) to reduce repeated attribute resolution.
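
The three points above can be illustrated with a minimal sketch. The PR's actual `_fast_generic_visit` is not shown in this conversation, so the class below is a hypothetical stand-in: `EarlyExitVisitor`, `_dispatch`, and the `visited` counter are assumptions made for illustration, and only `found_any_target_function`, `visit_Call`, and the `_fast_generic_visit` name come from the description.

```python
import ast


class EarlyExitVisitor(ast.NodeVisitor):
    """Hypothetical stand-in for ImportAnalyzer; only the traversal is sketched."""

    def __init__(self, target_names):
        self.target_names = set(target_names)
        self.found_any_target_function = False
        self.visited = 0  # instrumentation for this example only

    def visit_Name(self, node):
        self.visited += 1
        if node.id in self.target_names:
            self.found_any_target_function = True

    def visit_Call(self, node):
        self.visited += 1
        if self.found_any_target_function:
            return
        self._fast_generic_visit(node)

    def _fast_generic_visit(self, node):
        # Inlined equivalent of ast.NodeVisitor.generic_visit, with an
        # early-exit check before each field and before each list item.
        for field in node._fields:
            if self.found_any_target_function:
                return
            value = getattr(node, field, None)
            if isinstance(value, list):
                for item in value:
                    if self.found_any_target_function:
                        return
                    if isinstance(item, ast.AST):
                        self._dispatch(item)
            elif isinstance(value, ast.AST):
                self._dispatch(value)

    def _dispatch(self, node):
        # Resolve the specific visit_* method with one getattr call and
        # fall back to the inlined traversal when none is defined.
        method = getattr(self, "visit_" + node.__class__.__name__, None)
        if method is not None:
            method(node)
        else:
            self._fast_generic_visit(node)


tree = ast.parse("foo(a, b, c, d)")
visitor = EarlyExitVisitor({"foo"})
visitor.visit(tree)
print(visitor.found_any_target_function, visitor.visited)  # True 2
```

In this sketch the traversal stops as soon as `foo` is matched, so the four argument names are never visited. With a target set that matches nothing, all six nodes (the call plus five names) are visited.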

Performance Impact by Test Case:

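  • Large-scale tests show the biggest gains (27-35% faster) because they process many AST nodes, so the per-node traversal savings add up. The exception is the run with found_any_target_function already set (2.31% faster), which presumably does almost no traversal work.
  • Basic tests with fewer nodes show moderate improvements (9-20% faster). The one basic test that presets found_any_target_function is 4.54% slower (421ns -> 441ns).
  • Edge cases are 12-20% faster, apart from the attribute call builtins.__import__('os'), which is 1.48% slower.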
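
A rough way to reproduce this kind of comparison locally is a small `timeit` harness. This is only a sketch under stated assumptions: `BaselineVisitor`, `InlinedVisitor`, and `best_time` are hypothetical names, the inlined walker is simplified (it does not dispatch to `visit_*` methods), and the timings it prints depend on the machine and Python version, so they will not match the figures in this PR.

```python
import ast
import timeit


class BaselineVisitor(ast.NodeVisitor):
    """Calls the base-class traversal, as the original code is described as doing."""

    def visit_Call(self, node):
        ast.NodeVisitor.generic_visit(self, node)


class InlinedVisitor(ast.NodeVisitor):
    """Simplified inlined traversal (assumption: no visit_* dispatch needed)."""

    def visit_Call(self, node):
        self._walk(node)

    def _walk(self, node):
        for field in node._fields:
            value = getattr(node, field, None)
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self._walk(item)
            elif isinstance(value, ast.AST):
                self._walk(value)


def best_time(visitor_cls, calls, repeat=5):
    """Return the best wall-clock time in seconds over `repeat` runs."""

    def run():
        visitor = visitor_cls()
        for call in calls:
            visitor.visit_Call(call)

    return min(timeit.repeat(run, number=1, repeat=repeat))


# Same shape of input as the large-scale generated tests: 1000 plain calls.
source = "\n".join(f"foo{i}()" for i in range(1000))
calls = [n for n in ast.walk(ast.parse(source)) if isinstance(n, ast.Call)]

baseline = best_time(BaselineVisitor, calls)
inlined = best_time(InlinedVisitor, calls)
print(f"baseline {baseline * 1e3:.2f} ms, inlined {inlined * 1e3:.2f} ms")
```

Taking the minimum over several repeats mirrors the "best of N runs" convention used in the runtime line above.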

The line profiler shows the optimization reduces time spent in generic_visit from 144.2ms to 107.9ms (about 25% less time), with the overall visit_Call method improving from 287.5ms to 210.3ms (about 27% less time). Because the savings apply to every node visited, the benefit grows with the size and complexity of the analyzed code, which matters most for AST analysis tools that process large codebases.
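
The percentages in this report appear to use two different conventions, which a short calculation makes visible. The helper names below are hypothetical. The reading that the headline and per-test figures are relative to the new runtime, while the profiler figure is a reduction relative to the old runtime, is inferred from the numbers rather than stated in the PR.

```python
def speedup_pct(before, after):
    """Speedup relative to the new runtime: (before / after - 1) * 100."""
    return (before / after - 1) * 100


def time_reduction_pct(before, after):
    """Share of the original runtime that was removed."""
    return (before - after) / before * 100


# Headline runtime: 13.0 ms -> 9.94 ms, reported as a 31% speedup.
print(round(speedup_pct(13.0, 9.94), 1))  # 30.8

# Profiler time in generic_visit: 144.2 ms -> 107.9 ms, reported as 25%.
print(round(time_reduction_pct(144.2, 107.9), 1))  # 25.2

# Profiler time in visit_Call: 287.5 ms -> 210.3 ms.
print(round(time_reduction_pct(287.5, 210.3), 1))  # 26.9
```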

Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 5073 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import ast

# imports
import pytest
from codeflash.discovery.discover_unit_tests import ImportAnalyzer

# unit tests

# Helper to parse code and get the Call node(s)
def get_call_nodes(code: str):
    tree = ast.parse(code)
    return [node for node in ast.walk(tree) if isinstance(node, ast.Call)]

# Basic Test Cases

def test_visit_call_basic_import_call_sets_flag():
    # Test that __import__ sets has_dynamic_imports to True
    code = "__import__('os')"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.visit_Call(calls[0]) # 10.6μs -> 9.66μs (9.44% faster)

def test_visit_call_basic_non_import_call_does_not_set_flag():
    # Test that calling a normal function does not set has_dynamic_imports
    code = "foo()"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.visit_Call(calls[0]) # 5.61μs -> 4.86μs (15.5% faster)

def test_visit_call_basic_multiple_calls_only_import_sets_flag():
    # Test that only __import__ sets flag, not other calls
    code = "foo(); __import__('os'); bar()"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    for call in calls:
        analyzer.visit_Call(call) # 17.5μs -> 14.6μs (20.0% faster)

def test_visit_call_basic_flag_not_set_if_found_any_target_function():
    # If found_any_target_function is True, __import__ should not set flag
    code = "__import__('os')"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.found_any_target_function = True
    analyzer.visit_Call(calls[0]) # 421ns -> 441ns (4.54% slower)

# Edge Test Cases


def test_visit_call_edge_import_with_kwargs():
    # __import__ called with the full set of positional arguments
    code = "__import__('os', globals(), locals(), [], 0)"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.visit_Call(calls[0]) # 22.4μs -> 19.4μs (15.3% faster)

def test_visit_call_edge_nested_calls():
    # __import__ inside another call
    code = "foo(__import__('os'))"
    calls = get_call_nodes(code)
    # There are two calls: __import__ and foo
    analyzer = ImportAnalyzer(set())
    # Visit both calls
    for call in calls:
        analyzer.visit_Call(call) # 20.2μs -> 17.0μs (19.0% faster)

def test_visit_call_edge_non_name_func():
    # __import__ called via getattr, should not set flag
    code = "getattr(__builtins__, '__import__')('os')"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    for call in calls:
        analyzer.visit_Call(call) # 23.6μs -> 19.9μs (18.3% faster)

def test_visit_call_edge_func_is_attribute():
    # __import__ called as an attribute, should not set flag
    code = "builtins.__import__('os')"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.visit_Call(calls[0]) # 12.6μs -> 12.8μs (1.48% slower)

def test_visit_call_edge_func_is_lambda():
    # Call an anonymous lambda (func is an ast.Lambda, not an ast.Name), should not set flag
    code = "(lambda x: x)('os')"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.visit_Call(calls[0]) # 14.7μs -> 12.2μs (20.1% faster)

def test_visit_call_edge_func_is_name_but_not_import():
    # Call a function named _import (not __import__), should not set flag
    code = "_import('os')"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.visit_Call(calls[0]) # 9.59μs -> 8.55μs (12.2% faster)

def test_visit_call_edge_no_calls():
    # No calls, flag should remain False
    code = "x = 1"
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    for call in calls:
        analyzer.visit_Call(call)

# Large Scale Test Cases

def test_visit_call_large_many_calls_one_import():
    # Many calls, only one __import__ call
    code = "\n".join([f"foo{i}()" for i in range(999)] + ["__import__('os')"])
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    for call in calls:
        analyzer.visit_Call(call) # 2.66ms -> 1.97ms (35.0% faster)

def test_visit_call_large_many_calls_no_import():
    # Many calls, none is __import__
    code = "\n".join([f"foo{i}()" for i in range(1000)])
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    for call in calls:
        analyzer.visit_Call(call) # 2.64ms -> 1.96ms (34.1% faster)

def test_visit_call_large_many_import_calls():
    # Many __import__ calls
    code = "\n".join([f"__import__('mod{i}')" for i in range(1000)])
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    for call in calls:
        analyzer.visit_Call(call) # 4.58ms -> 3.59ms (27.6% faster)

def test_visit_call_large_found_any_target_function_stops_processing():
    # found_any_target_function is True, even with many __import__ calls, flag should not change
    code = "\n".join([f"__import__('mod{i}')" for i in range(1000)])
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    analyzer.found_any_target_function = True
    for call in calls:
        analyzer.visit_Call(call) # 156μs -> 152μs (2.31% faster)

def test_visit_call_large_mixed_calls_and_imports():
    # Mix of normal and __import__ calls
    code = "\n".join([f"foo{i}()" if i % 10 else f"__import__('mod{i}')" for i in range(1000)])
    calls = get_call_nodes(code)
    analyzer = ImportAnalyzer(set())
    for call in calls:
        analyzer.visit_Call(call) # 2.87ms -> 2.15ms (33.5% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import ast

# imports
import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import ImportAnalyzer

# unit tests

# Helper to run visit_Call on a given code string and return analyzer
def run_visit_call_on_code(code: str, function_names_to_find=None):
    if function_names_to_find is None:
        function_names_to_find = {"foo", "bar.baz"}
    tree = ast.parse(code)
    analyzer = ImportAnalyzer(function_names_to_find)
    # Find all Call nodes and run visit_Call on each
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            analyzer.visit_Call(node)
    return analyzer

# ------------------ Basic Test Cases ------------------

def test_visit_call_basic_no_calls():
    """Test with code that has no function calls at all."""
    code = "x = 1 + 2"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_basic_regular_call():
    """Test with a regular function call that is not __import__."""
    code = "foo()"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_basic_multiple_calls():
    """Test with multiple calls, none being __import__."""
    code = "foo(); bar(); baz(1, 2)"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_basic_import_call():
    """Test with __import__ call."""
    code = "__import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_basic_mixed_calls():
    """Test with both __import__ and other calls."""
    code = "foo(); __import__('sys'); bar()"
    analyzer = run_visit_call_on_code(code)

# ------------------ Edge Test Cases ------------------

def test_visit_call_edge_nested_import_call():
    """Test with __import__ inside another call."""
    code = "wrapper(__import__('math'))"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_as_attribute():
    """Test with __import__ as an attribute (should NOT match)."""
    code = "mod.__import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_with_args():
    """Test __import__ with multiple arguments."""
    code = "__import__('os', globals(), locals(), [], 0)"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_lambda():
    """Test __import__ inside a lambda."""
    code = "(lambda: __import__('os'))()"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_comprehension():
    """Test __import__ inside a list comprehension."""
    code = "[__import__(mod) for mod in ['os', 'sys']]"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_class_method():
    """Test __import__ inside a class method."""
    code = '''
class A:
    def m(self):
        return __import__("os")
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_try_except():
    """Test __import__ inside a try/except."""
    code = '''
try:
    __import__("os")
except ImportError:
    pass
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_with():
    """Test __import__ inside a with statement."""
    code = '''
with open("file.txt") as f:
    __import__("os")
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_function_default_arg():
    """Test __import__ as a default argument value."""
    code = '''
def f(x=__import__("os")):
    pass
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_decorator():
    """Test __import__ used as a decorator."""
    code = '''
@__import__("os")
def f():
    pass
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_generator():
    """Test __import__ inside a generator expression."""
    code = "(i for i in range(__import__('os').getpid()))"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_conditional():
    """Test __import__ inside an if statement."""
    code = '''
if True:
    __import__("os")
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_while():
    """Test __import__ inside a while loop."""
    code = '''
while False:
    __import__("os")
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_for():
    """Test __import__ inside a for loop."""
    code = '''
for i in range(1):
    __import__("os")
'''
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_augassign():
    """Test __import__ in an augmented assignment."""
    code = "x += __import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_assert():
    """Test __import__ in an assert statement."""
    code = "assert __import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_return():
    """Test __import__ in a return statement."""
    code = "def f(): return __import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_yield():
    """Test __import__ in a yield statement."""
    code = "def f(): yield __import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_raise():
    """Test __import__ in a raise statement."""
    code = "raise __import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_expr():
    """Test __import__ as a standalone expression."""
    code = "__import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_complex_expr():
    """Test __import__ inside a complex expression."""
    code = "result = foo(__import__('os'), bar(1))"
    analyzer = run_visit_call_on_code(code)


def test_visit_call_edge_import_with_keyword_args():
    """Test __import__ with keyword arguments."""
    code = "__import__(name='os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_with_star_args():
    """Test __import__ with *args and **kwargs."""
    code = "__import__(*args, **kwargs)"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_with_nonstring_arg():
    """Test __import__ with a non-string argument."""
    code = "__import__(42)"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_with_no_args():
    """Test __import__ with no arguments (invalid, but should be detected)."""
    code = "__import__()"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_docstring():
    """Test __import__ in a docstring (should NOT be detected)."""
    code = '"""__import__("os")"""'
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_comment():
    """Test __import__ in a comment (should NOT be detected)."""
    code = "# __import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_edge_import_in_string_literal():
    """Test __import__ in a string literal (should NOT be detected)."""
    code = 'x = "__import__(\'os\')"'
    analyzer = run_visit_call_on_code(code)

# ------------------ Large Scale Test Cases ------------------

def test_visit_call_large_many_calls():
    """Test with a large number of function calls, only one being __import__."""
    code = "\n".join([f"foo{i}()" for i in range(999)]) + "\n__import__('os')"
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_all_regular_calls():
    """Test with a large number of regular calls, none being __import__."""
    code = "\n".join([f"foo{i}()" for i in range(1000)])
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_all_import_calls():
    """Test with a large number of __import__ calls."""
    code = "\n".join([f"__import__('os{i}')" for i in range(1000)])
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_mixed_calls():
    """Test with a mix of __import__ and regular calls."""
    code = "\n".join([f"foo{i}()" if i % 2 == 0 else f"__import__('os{i}')" for i in range(1000)])
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_import_in_functions():
    """Test with many functions, each calling __import__."""
    code = "\n".join([f"def f{i}():\n    __import__('os{i}')\n" for i in range(100)])
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_import_in_classes():
    """Test with many classes, each with a method calling __import__."""
    code = "\n".join([f"class C{i}:\n    def m(self):\n        __import__('os{i}')\n" for i in range(100)])
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_nested_calls():
    """Test with deeply nested calls including __import__."""
    code = "foo(" * 10 + "__import__('os')" + ")" * 10
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_import_in_comprehensions():
    """Test with __import__ inside many comprehensions."""
    code = "\n".join([f"[__import__('os{i}') for j in range(10)]" for i in range(100)])
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_import_in_decorators():
    """Test with many functions decorated with __import__."""
    code = "\n".join([f"@__import__('os{i}')\ndef f{i}(): pass" for i in range(100)])
    analyzer = run_visit_call_on_code(code)

def test_visit_call_large_import_in_augassignments():
    """Test with many augmented assignments using __import__."""
    code = "\n".join([f"x{i} += __import__('os{i}')" for i in range(1000)])
    analyzer = run_visit_call_on_code(code)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr867-2025-11-05T08.40.37 and push.

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels Nov 5, 2025
@aseembits93 aseembits93 merged commit 5ff60e2 into inspect-signature-issue Nov 5, 2025
21 of 23 checks passed
@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr867-2025-11-05T08.40.37 branch November 5, 2025 09:29